# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Send/Recv ops.

The following _Send()/_Recv() are adapted from python op wrappers
generated by python_op_gen_main. python_op_gen_main.cc's
PrintAllPythonOps needs to be updated to export internal ops.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from google.protobuf import text_format as _text_format
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_shape as _tensor_shape


def _Recv(tensor_type, tensor_name, send_device, recv_device, name=None):
  r"""Adds a `_Recv` op that receives a named tensor sent from another device.

  The `send_device_incarnation` attr is always set to 0 and
  `client_terminated` is always set to False.

  Args:
    tensor_type: A `tf.DType`. The dtype of the tensor to receive.
    tensor_name: A `string`. The name of the tensor to receive.
    send_device: A `string`. The name of the device sending the tensor.
    recv_device: A `string`. The name of the device receiving the tensor.
    name: A name for the operation (optional). Falls back to "Recv" when
      empty or None.

  Returns:
    A `Tensor` of type `tensor_type`. The tensor to receive.
  """
  # Any falsy name (None or "") is replaced by the default op name.
  op_name = name or "Recv"
  return _op_def_lib.apply_op(
      "_Recv",
      tensor_type=tensor_type,
      tensor_name=tensor_name,
      send_device=send_device,
      send_device_incarnation=0,
      recv_device=recv_device,
      client_terminated=False,
      name=op_name)


# Calls the `RegisterShape("_Recv")` registrar with `None`, i.e. no Python
# shape function is supplied for the "_Recv" op.
# NOTE(review): presumably this defers shape inference to the op's default
# (C++) shape function -- confirm against the `RegisterShape` in use.
_ops.RegisterShape("_Recv")(None)


def _Send(tensor, tensor_name, send_device, recv_device, name=None):
  r"""Adds a `_Send` op that sends a named tensor to another device.

  The `send_device_incarnation` attr is always set to 0 and
  `client_terminated` is always set to False.

  Args:
    tensor: A `Tensor`. The tensor to send.
    tensor_name: A `string`. The name of the tensor to send.
    send_device: A `string`. The name of the device sending the tensor.
    recv_device: A `string`. The name of the device receiving the tensor.
    name: A name for the operation (optional). Falls back to "Send" when
      empty or None.

  Returns:
    The created Operation.
  """
  # Any falsy name (None or "") is replaced by the default op name.
  op_name = name or "Send"
  return _op_def_lib.apply_op(
      "_Send",
      tensor=tensor,
      tensor_name=tensor_name,
      send_device=send_device,
      send_device_incarnation=0,
      recv_device=recv_device,
      client_terminated=False,
      name=op_name)


# Calls the `RegisterShape("_Send")` registrar with `None`, i.e. no Python
# shape function is supplied for the "_Send" op.
# NOTE(review): presumably this defers shape inference to the op's default
# (C++) shape function -- confirm against the `RegisterShape` in use.
_ops.RegisterShape("_Send")(None)


def _InitOpDefLibrary():
  """Builds an `OpDefLibrary` from the ops declared in `op_list_ascii`.

  The text-format proto stored on `_InitOpDefLibrary.op_list_ascii` is merged
  into a fresh `OpList`. That list is registered with the op-def registry and
  then added to a newly created `OpDefLibrary`.

  Returns:
    The `OpDefLibrary` containing the parsed ops.
  """
  parsed_ops = _op_def_pb2.OpList()
  # The attribute is read at call time, so it must be assigned before the
  # first call to this function.
  _text_format.Merge(_InitOpDefLibrary.op_list_ascii, parsed_ops)
  _op_def_registry.register_op_list(parsed_ops)

  library = _op_def_library.OpDefLibrary()
  library.add_op_list(parsed_ops)
  return library


# Text-format `OpList` proto declaring the "_Recv", "_Send", "_XLARecv" and
# "_XLASend" ops. It is stored as an attribute on `_InitOpDefLibrary`, which
# parses it with `text_format.Merge` when called, so this assignment must
# happen before `_InitOpDefLibrary()` is invoked.
_InitOpDefLibrary.op_list_ascii = """op {
  name: "_Recv"
  output_arg {
    name: "tensor"
    type_attr: "tensor_type"
  }
  attr {
    name: "tensor_type"
    type: "type"
  }
  attr {
    name: "tensor_name"
    type: "string"
  }
  attr {
    name: "send_device"
    type: "string"
  }
  attr {
    name: "send_device_incarnation"
    type: "int"
  }
  attr {
    name: "recv_device"
    type: "string"
  }
  attr {
    name: "client_terminated"
    type: "bool"
    default_value {
      b: false
    }
  }
  is_stateful: true
}
op {
  name: "_Send"
  input_arg {
    name: "tensor"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
  }
  attr {
    name: "tensor_name"
    type: "string"
  }
  attr {
    name: "send_device"
    type: "string"
  }
  attr {
    name: "send_device_incarnation"
    type: "int"
  }
  attr {
    name: "recv_device"
    type: "string"
  }
  attr {
    name: "client_terminated"
    type: "bool"
    default_value {
      b: false
    }
  }
  is_stateful: true
}
op {
  name: "_XLARecv"
  output_arg {
    name: "tensor"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
  }
  attr {
    name: "tensor_name"
    type: "string"
  }
  attr {
    name: "shape"
    type: "shape"
  }
  is_stateful: true
}
op {
  name: "_XLASend"
  input_arg {
    name: "tensor"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
  }
  attr {
    name: "tensor_name"
    type: "string"
  }
  is_stateful: true
}
"""

# Module-level op library that `_Recv` and `_Send` use via `apply_op`. It is
# built once at import time, after `_InitOpDefLibrary.op_list_ascii` has been
# assigned.
_op_def_lib = _InitOpDefLibrary()


def _TpuCore(device):
  """Returns the TPU core represented by <device>, or -1 if not TPU.

  A device is treated as a TPU device when its name contains the marker
  "device:TPU_REPLICATED_CORE:"; the text following the first occurrence of
  that marker is parsed as the integer core id.

  Args:
    device: A `string`. The device name to inspect.

  Returns:
    The core id as an `int`, or -1 if the marker does not occur in `device`.

  Raises:
    ValueError: If the text following the marker is not a valid integer.
  """
  prefix = "device:TPU_REPLICATED_CORE:"
  # Split at the marker itself rather than slicing `len(prefix)` characters
  # off the front of the string: the marker may be preceded by other
  # components (e.g. "/job:worker/.../device:TPU_REPLICATED_CORE:1"), in which
  # case a front-anchored slice would hand a non-numeric string to int().
  _, marker, core = device.partition(prefix)
  if not marker:
    return -1
  return int(core)


class Channel(object):
  """A communication channel for moving tensors from one device to another.

  `Send` may be called at most once per channel. Non-TPU channels use the
  `_Send`/`_Recv` ops; TPU channels dispatch to `TpuSend`/`TpuRecv`, which
  raise `NotImplementedError` in this class.
  """

  def __init__(self, dtype, shape, send_device, recv_device, name=None):
    """Construct a channel.

    Args:
      dtype: The dtype of tensors sent through the channel.
      shape: The shape of tensors sent through the channel. Must not be None,
        and must be a fully defined shape for TPUs.
      send_device: A fully-specified tensorflow device.
      recv_device: A fully-specified tensorflow device.
      name: A name for the channel (optional). Defaults to "channel"; the
        name is made unique within the current default graph.

    Raises:
      AssertionError: If there is no default graph, `shape` is None, exactly
        one of the two devices is a TPU core, or (for TPU channels) the shape
        is not fully defined or both devices are the same core.
    """
    graph = _ops.get_default_graph()
    assert graph, "A channel is scoped within a tf.Graph"

    self._dtype = dtype
    self._send_device = send_device
    self._recv_device = recv_device
    self._name = graph.unique_name(name or "channel")

    assert shape is not None
    self._shape = _tensor_shape.TensorShape(shape)

    # A core id of -1 means the device is not a TPU core.
    self._send_tpu_core = _TpuCore(send_device)
    self._recv_tpu_core = _TpuCore(recv_device)
    self._send_called = False
    self._recv_op = None

    # Both endpoints must be TPU cores, or neither.
    send_on_tpu = self._send_tpu_core != -1
    recv_on_tpu = self._recv_tpu_core != -1
    assert send_on_tpu == recv_on_tpu, (
        "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))

    if self._send_tpu_core < 0:
      # The remaining checks only apply to TPU channels.
      return

    assert self._shape.is_fully_defined(), (
        "TPU channel must have fully defined shape. Name: %s, shape: %s" %
        (self._name, self._shape))
    assert self._send_tpu_core != self._recv_tpu_core, (
        "TPU send/recv must be cross-core: %s and %s" %
        (send_device, recv_device))

  def TpuSend(self, tensor):
    """Sends `tensor` over a TPU channel; not implemented in this class."""
    raise NotImplementedError("TPU send op not implemented.")

  def Send(self, tensor):
    """Sends a tensor through the channel.

    Args:
      tensor: The tensor to send. Its dtype must equal the channel's dtype.

    Returns:
      The result of `_Send` for non-TPU channels, or of `TpuSend` (called
      under the send device scope) for TPU channels.

    Raises:
      AssertionError: If the dtype does not match or `Send` was already
        called on this channel.
    """
    assert tensor.dtype == self._dtype
    assert not self._send_called, (
        "Send called multiple times for %s" % self._name)
    self._send_called = True

    if self._send_tpu_core != -1:
      with _ops.device(self._send_device):
        return self.TpuSend(tensor)
    return _Send(tensor, self._name, self._send_device, self._recv_device)

  def TpuRecv(self):
    """Receives a tensor over a TPU channel; not implemented in this class."""
    raise NotImplementedError("TPU recv op not implemented.")

  def Recv(self):
    """Receives a tensor from the channel.

    Returns:
      The result of `_Recv` for non-TPU channels, or of `TpuRecv` (called
      under the recv device scope) for TPU channels.
    """
    if self._send_tpu_core != -1:
      with _ops.device(self._recv_device):
        return self.TpuRecv()
    return _Recv(self._dtype, self._name,
                 self._send_device, self._recv_device)
